package com.uyuni.rpc.client.consumer;

import com.uyuni.rpc.client.loadbalance.LoadBalanceStrategies;
import com.uyuni.rpc.common.bean.ChannelGroups;
import com.uyuni.rpc.common.bean.UnresolvedAddress;
import com.uyuni.rpc.common.loadbalance.LoadBalanceStrategy;

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CopyOnWriteArrayList;

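/**
 * @author BazingaLyn
 * @description Abstract base class for the consumer side. Its responsibilities are:
 * 1) holding the provider information obtained from the registry for each service
 * 2) holding the load-balancing strategy configured for each service
 * @time September 1, 2016
 * @modifytime
 */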
public abstract class AbstractDefaultConsumer implements Consumer {

    /**
     * Process-wide registry: service name -> channel groups of the providers of that service.
     * The reference is never reassigned, so it is {@code final}; the map and the lists it holds
     * are concurrent collections and are mutated in place.
     */
    private static final ConcurrentMap<String, CopyOnWriteArrayList<ChannelGroups>> groups = new ConcurrentHashMap<>();
    /**
     * Provider address -> the channel group for that address.
     * Not populated by this class; presumably maintained by subclasses (verify against them).
     */
    protected final ConcurrentMap<UnresolvedAddress, ChannelGroups> addressGroups = new ConcurrentHashMap<>();
    /** Service name -> load-balancing strategy configured for that service. */
    protected final ConcurrentHashMap<String, LoadBalanceStrategy> loadConcurrentHashMap = new ConcurrentHashMap<>();

    /**
     * Registers a channel group as a provider of the given service, unless it is already registered.
     *
     * @param serviceName name of the service
     * @param group       channel group to register
     * @return {@code true} if the group was added; {@code false} if it was already present or if
     * either argument is {@code null}
     */
    public static boolean addIfAbsent(String serviceName, ChannelGroups group) {
        // CopyOnWriteArrayList accepts null elements; a null entry would end up in the list
        // handed to the load-balance selectors, so it is rejected here. A null service name
        // cannot be a key of the backing ConcurrentHashMap.
        if (serviceName == null || group == null) {
            return false;
        }
        CopyOnWriteArrayList<ChannelGroups> providers = groups.get(serviceName);
        if (providers == null) {
            CopyOnWriteArrayList<ChannelGroups> created = new CopyOnWriteArrayList<>();
            // Another thread may have installed a list in the meantime; if so, use that one.
            CopyOnWriteArrayList<ChannelGroups> existing = groups.putIfAbsent(serviceName, created);
            providers = (existing != null) ? existing : created;
        }
        return providers.addIfAbsent(group);
    }

    /**
     * Removes a channel group from the given service's provider list, e.g. when the group has
     * become invalid or gone offline.
     *
     * @param serviceName name of the service
     * @param group       channel group to remove
     * @return {@code true} if the group was registered for the service and has been removed;
     * {@code false} otherwise (including a {@code null} or unknown service name)
     */
    public static boolean removedIfAbsent(String serviceName, ChannelGroups group) {
        if (serviceName == null) {
            return false;
        }
        CopyOnWriteArrayList<ChannelGroups> providers = groups.get(serviceName);
        return providers != null && providers.remove(group);
    }

    /**
     * Looks up the channel groups registered for a service.
     *
     * @param service name of the service
     * @return the live list held by the registry (not a copy), or {@code null} if the service
     * name is {@code null} or nothing has been registered for it
     */
    public static CopyOnWriteArrayList<ChannelGroups> getChannelGroupsByServiceName(String service) {
        return (service == null) ? null : groups.get(service);
    }

    /**
     * Sets the load-balancing strategy for a service, replacing any previous one.
     * Both arguments must be non-null: the backing {@link ConcurrentHashMap} rejects null keys
     * and values with a {@link NullPointerException}.
     *
     * @param serviceName         name of the service
     * @param loadBalanceStrategy strategy to use for that service
     */
    @Override
    public void setServiceLoadBalanceStrategy(String serviceName, LoadBalanceStrategy loadBalanceStrategy) {
        loadConcurrentHashMap.put(serviceName, loadBalanceStrategy);
    }


    /**
     * Exposes the process-wide registry.
     *
     * @return the live registry map itself (not a copy); changes made through it are visible
     * to every consumer
     */
    public static ConcurrentMap<String, CopyOnWriteArrayList<ChannelGroups>> getGroups() {
        return groups;
    }

    /**
     * Picks one channel group among the providers registered for a service.
     * <p>
     * Strategy resolution order: the strategy configured for the service through
     * {@link #setServiceLoadBalanceStrategy}, then {@code directBalanceStrategy}, then
     * {@code LoadBalanceStrategy.WEIGHTINGRANDOM}.
     *
     * @param serviceName           name of the service
     * @param directBalanceStrategy strategy used when none is configured for the service; may be
     *                              {@code null}
     * @return the selected channel group, or {@code null} if the service name is {@code null},
     * no provider is registered for the service, or the resolved strategy is not one handled here
     */
    @Override
    public ChannelGroups loadBalance(String serviceName, LoadBalanceStrategy directBalanceStrategy) {
        if (serviceName == null) {
            return null;
        }
        // No providers means nothing to select, whatever the strategy.
        CopyOnWriteArrayList<ChannelGroups> candidates = groups.get(serviceName);
        if (candidates == null || candidates.isEmpty()) {
            return null;
        }

        LoadBalanceStrategy strategy = loadConcurrentHashMap.get(serviceName);
        if (strategy == null) {
            strategy = (directBalanceStrategy != null)
                    ? directBalanceStrategy
                    : LoadBalanceStrategy.WEIGHTINGRANDOM;
        }

        switch (strategy) {
            case RANDOM:
                return LoadBalanceStrategies.RANDOMSTRATEGIES.select(candidates);
            case WEIGHTINGRANDOM:
                return LoadBalanceStrategies.WEIGHTRANDOMSTRATEGIES.select(candidates);
            case ROUNDROBIN:
                return LoadBalanceStrategies.ROUNDROBIN.select(candidates);
            default:
                return null;
        }
    }

}
